{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/airaria/TextBrewer/blob/master/examples/notebook_examples/msra_ner.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3Kh5ekwgrCG4"
      },
      "source": [
        "This notebook shows how to fine-tune a BERT teacher model on the MSRA NER dataset (`msra_ner`) and how to distill it into a smaller student model with TextBrewer.\n",
        "\n",
        "Detailed documentation can be found here:\n",
        "https://github.com/airaria/TextBrewer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xi_1qOral-LE"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.\n",
        "if torch.cuda.is_available():\n",
        "    device = 'cuda'\n",
        "else:\n",
        "    device = 'cpu'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4cahHLZXmI81"
      },
      "outputs": [],
      "source": [
        "# Upper bounds keep the APIs used later in this notebook available:\n",
        "#  - datasets.load_metric was removed in datasets 3.0\n",
        "#  - TrainingArguments(evaluation_strategy=...) was removed in transformers 4.46\n",
        "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
        "%pip install -q \"datasets<3.0\" \"transformers<4.46\" seqeval textbrewer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lfq_aPKhe9jJ"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import torch\n",
        "from transformers import BertForSequenceClassification, BertTokenizer,BertConfig,BertForTokenClassification\n",
        "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
        "from transformers import Trainer, TrainingArguments\n",
        "from transformers import pipeline\n",
        "from datasets import load_dataset,load_metric"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NvnDEfRewRhJ"
      },
      "source": [
        "### Prepare dataset to train"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZtoFe4FPmWh_"
      },
      "outputs": [],
      "source": [
        "# Experiment configuration.\n",
        "model_checkpoint = \"bert-base-chinese\"  # pretrained weights + tokenizer for the teacher\n",
        "task = \"ner\"  # selects the `{task}_tags` column of the dataset\n",
        "batch_size = 8  # per-device batch size for teacher fine-tuning and evaluation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OhvaZdRqmow6"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset, load_metric\n",
        "\n",
        "# Fetch the msra_ner corpus (all of its splits) from the Hugging Face Hub.\n",
        "datasets = load_dataset(path=\"msra_ner\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WgM6aADvnIm5"
      },
      "outputs": [],
      "source": [
        "# The tag column is a sequence feature whose inner ClassLabel `names` map tag ids to strings.\n",
        "tag_feature = datasets[\"train\"].features[f\"{task}_tags\"]\n",
        "label_list = tag_feature.feature.names\n",
        "label_list"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vO48GEqlnRuh"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoTokenizer\n",
        "\n",
        "# Tokenizer of the checkpoint the teacher is fine-tuned from.\n",
        "tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_checkpoint)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6fO7iYQanY1G"
      },
      "outputs": [],
      "source": [
        "def tokenize_and_align_labels(examples, label_all_tokens=False):\n",
        "    \"\"\"Tokenize pre-split words and align the word-level tags to the sub-word tokens.\n",
        "\n",
        "    Args:\n",
        "        examples: a batch from `datasets.map(..., batched=True)` holding a \"tokens\"\n",
        "            column (lists of words) and the `{task}_tags` column (lists of tag ids).\n",
        "        label_all_tokens: if True, every sub-token of a word receives the word's tag;\n",
        "            if False (default), only the first sub-token does and the rest get -100.\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output with an extra \"labels\" entry (one list per example).\n",
        "    \"\"\"\n",
        "    tokenized_inputs = tokenizer(examples[\"tokens\"], truncation=True, is_split_into_words=True)\n",
        "\n",
        "    labels = []\n",
        "    for i, label in enumerate(examples[f\"{task}_tags\"]):\n",
        "        word_ids = tokenized_inputs.word_ids(batch_index=i)\n",
        "        previous_word_idx = None\n",
        "        label_ids = []\n",
        "        for word_idx in word_ids:\n",
        "            if word_idx is None:\n",
        "                # Special tokens have no word id; -100 is ignored by the loss and is\n",
        "                # filtered out again in compute_metrics.\n",
        "                label_ids.append(-100)\n",
        "            elif word_idx != previous_word_idx:\n",
        "                # First sub-token of a word carries the word's tag.\n",
        "                label_ids.append(label[word_idx])\n",
        "            else:\n",
        "                # Continuation sub-token of the same word.\n",
        "                label_ids.append(label[word_idx] if label_all_tokens else -100)\n",
        "            previous_word_idx = word_idx\n",
        "        labels.append(label_ids)\n",
        "\n",
        "    tokenized_inputs[\"labels\"] = labels\n",
        "    return tokenized_inputs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n9iuLfNDoqz6"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def _logits_from(predictions):\n",
        "    \"\"\"Return the logits from Trainer predictions.\n",
        "\n",
        "    When the model also returns hidden states (output_hidden_states=True), the\n",
        "    Trainer passes a tuple whose first element is the logits; otherwise it passes\n",
        "    the logits array directly.\n",
        "    \"\"\"\n",
        "    return predictions[0] if isinstance(predictions, tuple) else predictions\n",
        "\n",
        "\n",
        "def _seqeval_scores(logits, label_ids):\n",
        "    \"\"\"Score token-classification logits against gold tag ids with seqeval.\n",
        "\n",
        "    Positions whose gold id is -100 (special or ignored sub-tokens) are dropped\n",
        "    before scoring, so they count neither as predictions nor as references.\n",
        "    Uses the globals `label_list` and `metric` defined in other cells.\n",
        "    \"\"\"\n",
        "    pred_ids = np.argmax(logits, axis=2)\n",
        "\n",
        "    true_predictions = []\n",
        "    true_labels = []\n",
        "    for pred_row, gold_row in zip(pred_ids, label_ids):\n",
        "        kept = [(pred, gold) for pred, gold in zip(pred_row, gold_row) if gold != -100]\n",
        "        true_predictions.append([label_list[pred] for pred, _ in kept])\n",
        "        true_labels.append([label_list[gold] for _, gold in kept])\n",
        "\n",
        "    results = metric.compute(predictions=true_predictions, references=true_labels)\n",
        "    return {\n",
        "        \"precision\": results[\"overall_precision\"],\n",
        "        \"recall\": results[\"overall_recall\"],\n",
        "        \"f1\": results[\"overall_f1\"],\n",
        "        \"accuracy\": results[\"overall_accuracy\"],\n",
        "    }\n",
        "\n",
        "\n",
        "def compute_metrics(p):\n",
        "    \"\"\"Trainer `compute_metrics` hook: seqeval scores for the EvalPrediction `p`.\n",
        "\n",
        "    Accepts either bare logits or a (logits, ...) tuple in `p.predictions`.\n",
        "    \"\"\"\n",
        "    return _seqeval_scores(_logits_from(p.predictions), p.label_ids)\n",
        "\n",
        "\n",
        "def compute_eval_metrics(p):\n",
        "    \"\"\"Metrics hook for the distilled student.\n",
        "\n",
        "    The student is built with output_hidden_states=True, so `p.predictions` is\n",
        "    presumably a (logits, hidden_states) tuple; compute_metrics handles that case.\n",
        "    \"\"\"\n",
        "    return compute_metrics(p)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "S9X0l225n7WZ"
      },
      "outputs": [],
      "source": [
        "from transformers import BertForTokenClassification, TrainingArguments, Trainer\n",
        "\n",
        "# Teacher: pretrained encoder plus a token-classification head with one output\n",
        "# per tag in label_list. Module.to() moves the model and returns it.\n",
        "num_tags = len(label_list)\n",
        "model = BertForTokenClassification.from_pretrained(model_checkpoint, num_labels=num_tags).to(device)\n",
        "model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "87HPRyytoYb8"
      },
      "outputs": [],
      "source": [
        "from transformers import DataCollatorForTokenClassification\n",
        "\n",
        "# Pads every batch dynamically; labels are padded with the collator's label_pad_token_id.\n",
        "data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "97wY0abnobYe"
      },
      "outputs": [],
      "source": [
        "# seqeval provides the entity-level precision / recall / F1 used by the metric hooks.\n",
        "metric = load_metric(path=\"seqeval\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xxl_04jup6ep"
      },
      "outputs": [],
      "source": [
        "# Tokenize every split batch-wise and attach the aligned \"labels\" column.\n",
        "tokenized_datasets = datasets.map(function=tokenize_and_align_labels, batched=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ENBYA77ToLwV"
      },
      "outputs": [],
      "source": [
        "# Teacher fine-tuning hyper-parameters; outputs are written under test-{task}.\n",
        "args = TrainingArguments(\n",
        "    f\"test-{task}\",\n",
        "    num_train_epochs=2,\n",
        "    learning_rate=2e-5,\n",
        "    weight_decay=0.01,\n",
        "    per_device_train_batch_size=batch_size,\n",
        "    per_device_eval_batch_size=batch_size,\n",
        "    evaluation_strategy = \"epoch\",  # evaluate on eval_dataset after every epoch\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qaj3gN7xovJ0"
      },
      "outputs": [],
      "source": [
        "# Fine-tune the teacher on the train split, evaluating on the test split.\n",
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=args,\n",
        "    data_collator=data_collator,\n",
        "    train_dataset=tokenized_datasets[\"train\"],\n",
        "    eval_dataset=tokenized_datasets[\"test\"],\n",
        "    tokenizer=tokenizer,\n",
        "    compute_metrics=compute_metrics,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NzxXm1K1o_Uh"
      },
      "outputs": [],
      "source": [
        "# Fine-tune the teacher; `model` itself is trained, and its weights are saved below.\n",
        "trainer.train()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Mount Google Drive so the teacher checkpoint saved next outlives the Colab session.\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ],
      "metadata": {
        "id": "XxdTV3biv2o1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wRiDIS5WuJ7s"
      },
      "outputs": [],
      "source": [
        "# Persist the fine-tuned teacher weights to Drive; the distillation section reloads them.\n",
        "teacher_ckpt_path = '/content/drive/MyDrive/msra_teacher_model.pt'\n",
        "torch.save(model.state_dict(), teacher_ckpt_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e5fK_phaslqc"
      },
      "source": [
        "### Start distillation"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from torch.utils.data import DataLoader, RandomSampler\n",
        "\n",
        "distill_batch_size = 32\n",
        "\n",
        "# Keep only the tokenizer outputs and \"labels\": drop every raw column of the original\n",
        "# dataset (id, tokens, tags) instead of hard-coding the tag column name.\n",
        "train_dataset = tokenized_datasets[\"train\"].remove_columns(datasets[\"train\"].column_names)\n",
        "train_dataloader = DataLoader(\n",
        "    train_dataset,\n",
        "    sampler=RandomSampler(train_dataset),\n",
        "    batch_size=distill_batch_size,\n",
        "    collate_fn=data_collator,\n",
        ")"
      ],
      "metadata": {
        "id": "eLs-39XvLCzp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z03IrKzzXYMZ"
      },
      "outputs": [],
      "source": [
        "import textbrewer\n",
        "from textbrewer import GeneralDistiller\n",
        "from textbrewer import TrainingConfig, DistillationConfig\n",
        "from transformers import BertForTokenClassification, BertConfig,BertTokenizer\n",
        "from transformers import get_linear_schedule_with_warmup\n",
        "import torch \n",
        "from torch.optim import AdamW"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YloYaVoWs14s"
      },
      "source": [
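        "Initialize the student model from a `BertConfig` and reload the fine-tuned teacher model.\n",
        "\n",
        "`bert_config_L3.json` describes a 3-layer BERT (the student) and `bert_config.json` describes the teacher's architecture; both files must be uploaded to `/content` before running the next cell."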
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_P7a1W9NXdH9"
      },
      "outputs": [],
      "source": [
        "# Student: a randomly initialised BERT described by bert_config_L3.json.\n",
        "bert_config_T3 = BertConfig.from_json_file('/content/bert_config_L3.json')\n",
        "bert_config_T3.output_hidden_states = True  # hidden states feed the intermediate_matches losses\n",
        "bert_config_T3.num_labels = len(label_list)\n",
        "\n",
        "student_model = BertForTokenClassification(bert_config_T3)\n",
        "\n",
        "# Teacher: bert_config.json must describe the same architecture as the fine-tuned\n",
        "# teacher, otherwise load_state_dict below fails on mismatched keys/shapes.\n",
        "bert_config = BertConfig.from_json_file('/content/bert_config.json')\n",
        "bert_config.output_hidden_states = True\n",
        "bert_config.num_labels = len(label_list)\n",
        "\n",
        "teacher_model = BertForTokenClassification(bert_config)\n",
        "# map_location lets a checkpoint saved in a GPU session load on a CPU-only runtime.\n",
        "teacher_state = torch.load('/content/drive/MyDrive/msra_teacher_model.pt', map_location=device)\n",
        "teacher_model.load_state_dict(teacher_state)\n",
        "\n",
        "teacher_model.to(device)\n",
        "student_model.to(device)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MKC_jnkuxTJn"
      },
      "source": [
        "The cells below distill the teacher model into the student model prepared above.\n",
        "\n",
        "When training finishes, the distilled checkpoints (`gs<step>.pkl`) will be in the `saved_models` folder of the Colab file browser (TextBrewer's default output directory)."
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def proc_fn(batch):\n",
        "    \"\"\"Pick the fields the BERT models are called with out of a collated batch.\"\"\"\n",
        "    model_input_keys = ('input_ids', 'token_type_ids', 'attention_mask', 'labels')\n",
        "    return {key: batch[key] for key in model_input_keys}"
      ],
      "metadata": {
        "id": "DLXYbGlb55Zm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C6K8OyXuXgsh"
      },
      "outputs": [],
      "source": [
        "num_epochs = 20\n",
        "num_training_steps = len(train_dataloader) * num_epochs\n",
        "\n",
        "optimizer = AdamW(student_model.parameters(), lr=1e-5)\n",
        "\n",
        "# Linear decay with 10% warm-up, sized for one step per batch over all epochs.\n",
        "scheduler_class = get_linear_schedule_with_warmup\n",
        "scheduler_args = {'num_warmup_steps': int(0.1 * num_training_steps), 'num_training_steps': num_training_steps}\n",
        "\n",
        "def simple_adaptor(batch, model_outputs):\n",
        "    \"\"\"Expose the logits and all hidden states to TextBrewer under the names it matches on.\"\"\"\n",
        "    return {\"logits\": model_outputs.logits, 'hidden': model_outputs.hidden_states}\n",
        "\n",
        "# Match hidden_states[0] and every 4th teacher layer to successive student layers.\n",
        "# No projection is configured, so hidden_mse presumably requires equal student and\n",
        "# teacher hidden sizes -- check bert_config_L3.json.\n",
        "distill_config = DistillationConfig(\n",
        "    intermediate_matches=[{\"layer_T\": 0, \"layer_S\": 0, \"feature\": \"hidden\", \"loss\": \"hidden_mse\", \"weight\": 1},\n",
        "                          {\"layer_T\": 4, \"layer_S\": 1, \"feature\": \"hidden\", \"loss\": \"hidden_mse\", \"weight\": 1},\n",
        "                          {\"layer_T\": 8, \"layer_S\": 2, \"feature\": \"hidden\", \"loss\": \"hidden_mse\", \"weight\": 1},\n",
        "                          {\"layer_T\": 12, \"layer_S\": 3, \"feature\": \"hidden\", \"loss\": \"hidden_mse\", \"weight\": 1}])\n",
        "\n",
        "# TrainingConfig defaults to device='cuda'; pass the detected device so distillation\n",
        "# also works on CPU-only runtimes and stays consistent with where the models live.\n",
        "train_config = TrainingConfig(device=device)\n",
        "distiller = GeneralDistiller(\n",
        "    train_config=train_config, distill_config=distill_config,\n",
        "    model_T=teacher_model, model_S=student_model,\n",
        "    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)\n",
        "\n",
        "with distiller:\n",
        "    distiller.train(optimizer, train_dataloader, num_epochs, scheduler_class=scheduler_class, scheduler_args=scheduler_args, callback=None,\n",
        "                    batch_postprocessor=proc_fn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BmTWyy9dyuqG"
      },
      "source": [
        "Then evaluate the distilled model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WTU4sYE-fKN2"
      },
      "outputs": [],
      "source": [
        "# Rebuild the student architecture and load a distilled checkpoint into it.\n",
        "bert_config_T3 = BertConfig.from_json_file('/content/bert_config_L3.json')\n",
        "# Kept True: compute_eval_metrics expects predictions to be a (logits, hidden_states) tuple.\n",
        "bert_config_T3.output_hidden_states = True\n",
        "bert_config_T3.num_labels = len(label_list)\n",
        "test_model = BertForTokenClassification(bert_config_T3)\n",
        "\n",
        "# NOTE(review): gs<step>.pkl is presumably a step-numbered checkpoint written by the\n",
        "# distiller and copied to Drive; point this at a checkpoint from your own run.\n",
        "distilled_ckpt_path = '/content/drive/MyDrive/model/gs2813.pkl'\n",
        "# map_location lets a checkpoint saved in a GPU session load on a CPU-only runtime.\n",
        "test_model.load_state_dict(torch.load(distilled_ckpt_path, map_location=device))\n",
        "test_model.to(device)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2pc2iuCAZmbh"
      },
      "outputs": [],
      "source": [
        "# Evaluation-only arguments: this Trainer is used solely for trainer.evaluate(), so\n",
        "# training options (learning rate, epochs, weight decay, eval strategy) have no effect.\n",
        "args = TrainingArguments(\n",
        "    \"distill-test\",\n",
        "    per_device_eval_batch_size=batch_size,\n",
        "    do_train=False,\n",
        "    do_eval=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m3VwmhxvZUDR"
      },
      "outputs": [],
      "source": [
        "# Trainer wrapping the distilled student; scored on the test split with compute_eval_metrics.\n",
        "trainer = Trainer(\n",
        "    model=test_model,\n",
        "    args=args,\n",
        "    data_collator=data_collator,\n",
        "    train_dataset=tokenized_datasets[\"train\"],\n",
        "    eval_dataset=tokenized_datasets[\"test\"],\n",
        "    tokenizer=tokenizer,\n",
        "    compute_metrics=compute_eval_metrics,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "32q6n0fGzjtX"
      },
      "outputs": [],
      "source": [
        "# Report the loss and seqeval metrics of the distilled student on the test split.\n",
        "trainer.evaluate()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "msra.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}